{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "from graph import Graph\r\n",
    "from pass_manager import pass_m\r\n",
    "import cv2\r\n",
    "import numpy as np\r\n",
    "import array"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# load onnx model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "inp_dim = [1, 1, 28, 28]\r\n",
    "graph = Graph('mnist', './data/mnist-8.onnx', inp_dim, debug=1)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "tensor: Parameter6[8] (8,1,1) producer:  ['_param'] \tconsumer:  ['Plus30']\n",
      "tensor: Input3[784] (1,1,28,28) producer:  [] \tconsumer:  ['Convolution28']\n",
      "tensor: Parameter194[10] (1,10) producer:  ['_param'] \tconsumer:  ['Plus214']\n",
      "tensor: ReLU114_Output_0[1] () producer:  ['ReLU114'] \tconsumer:  ['Pooling160']\n",
      "tensor: Plus214_Output_0[1] () producer:  ['Plus214'] \tconsumer:  []\n",
      "tensor: Plus30_Output_0[1] () producer:  ['Plus30'] \tconsumer:  ['ReLU32']\n",
      "tensor: Times212_Output_0[1] () producer:  ['Times212'] \tconsumer:  ['Plus214']\n",
      "tensor: Parameter5[200] (8,1,5,5) producer:  ['_param'] \tconsumer:  ['Convolution28']\n",
      "tensor: Pooling160_Output_0_reshape0_shape[2] (2) producer:  ['_param'] \tconsumer:  ['Times212_reshape0']\n",
      "tensor: Parameter88[16] (16,1,1) producer:  ['_param'] \tconsumer:  ['Plus112']\n",
      "tensor: Pooling66_Output_0[1] () producer:  ['Pooling66'] \tconsumer:  ['Convolution110']\n",
      "tensor: Parameter193_reshape1_shape[2] (2) producer:  ['_param'] \tconsumer:  ['Times212_reshape1']\n",
      "tensor: Convolution110_Output_0[1] () producer:  ['Convolution110'] \tconsumer:  ['Plus112']\n",
      "tensor: ReLU32_Output_0[1] () producer:  ['ReLU32'] \tconsumer:  ['Pooling66']\n",
      "tensor: Pooling160_Output_0_reshape0[1] () producer:  ['Times212_reshape0'] \tconsumer:  ['Times212']\n",
      "tensor: Pooling160_Output_0[1] () producer:  ['Pooling160'] \tconsumer:  ['Times212_reshape0']\n",
      "tensor: Parameter193_reshape1[1] () producer:  ['Times212_reshape1'] \tconsumer:  ['Times212']\n",
      "tensor: Parameter87[3200] (16,8,5,5) producer:  ['_param'] \tconsumer:  ['Convolution110']\n",
      "tensor: Convolution28_Output_0[1] () producer:  ['Convolution28'] \tconsumer:  ['Plus30']\n",
      "tensor: Parameter193[2560] (16,4,4,10) producer:  ['_param'] \tconsumer:  ['Times212_reshape1']\n",
      "tensor: Plus112_Output_0[1] () producer:  ['Plus112'] \tconsumer:  ['ReLU114']\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "this prints all tensors in the graph, with each tensor's producer nodes and consumer nodes\n",
    "- the tensor `Input3` with empty producer is the graph's input tensor\n",
    "- the tensor `Plus214_Output_0` with empty consumer is the graph's output tensor\n",
    "\n",
    "# print(graph)\n",
    "this will print all nodes in the graph as (node.name, node.op_type)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "print(graph)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "name:Times212_reshape1\top_type:Reshape\n",
      "name:Convolution28\top_type:Conv\n",
      "name:Plus30\top_type:Add\n",
      "name:ReLU32\top_type:Relu\n",
      "name:Pooling66\top_type:MaxPool\n",
      "name:Convolution110\top_type:Conv\n",
      "name:Plus112\top_type:Add\n",
      "name:ReLU114\top_type:Relu\n",
      "name:Pooling160\top_type:MaxPool\n",
      "name:Times212_reshape0\top_type:Reshape\n",
      "name:Times212\top_type:MatMul\n",
      "name:Plus214\top_type:Add\n",
      "\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# run pass\n",
    "this will fuse Conv+Add and MatMul+Add pairs, and remove the Reshape nodes"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "graph = pass_m.run_all_pass(graph)\r\n",
    "print(graph)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "name:Convolution28\top_type:Conv_Add_fused\n",
      "name:ReLU32\top_type:Relu\n",
      "name:Pooling66\top_type:MaxPool\n",
      "name:Convolution110\top_type:Conv_Add_fused\n",
      "name:ReLU114\top_type:Relu\n",
      "name:Pooling160\top_type:MaxPool\n",
      "name:Times212\top_type:MatMul_Add_fused\n",
      "\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# dump graph ir"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "graph.dump_ir()"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "%0 = Input3\n",
      "%1 = Parameter5\n",
      "%2 = Parameter6\n",
      "%3 = Conv_Add_fused(%0,%1,%2)\n",
      "%4 = Relu(%3)\n",
      "%5 = MaxPool(%4)\n",
      "%6 = Parameter87\n",
      "%7 = Parameter88\n",
      "%8 = Conv_Add_fused(%5,%6,%7)\n",
      "%9 = Relu(%8)\n",
      "%10 = MaxPool(%9)\n",
      "%11 = Parameter193_reshape1\n",
      "%12 = Parameter194\n",
      "%13 = MatMul_Add_fused(%10,%11,%12)\n",
      "\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# shape inference, get all tensors"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "graph.infershape()\r\n",
    "tensor_map = graph.get_all_tensors(debug=1)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "0 tensor: Input3[784] (1,1,28,28)\n",
      "1 tensor: Parameter5[200] (8,1,5,5)\n",
      "2 tensor: Parameter6[8] (8)\n",
      "3 tensor: Plus30_Output_0[6272] (1,8,28,28)\n",
      "4 tensor: ReLU32_Output_0[6272] (1,8,28,28)\n",
      "5 tensor: Pooling66_Output_0[1568] (1,8,14,14)\n",
      "6 tensor: Parameter87[3200] (16,8,5,5)\n",
      "7 tensor: Parameter88[16] (16)\n",
      "8 tensor: Plus112_Output_0[3136] (1,16,14,14)\n",
      "9 tensor: ReLU114_Output_0[3136] (1,16,14,14)\n",
      "10 tensor: Pooling160_Output_0[256] (1,256)\n",
      "11 tensor: Parameter193_reshape1[2560] (256,10)\n",
      "12 tensor: Parameter194[10] (10)\n",
      "13 tensor: Plus214_Output_0[10] (1,10)\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# mnist demo (with preprocess, postprocess)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "def preprocess(img_path, inp_dim):\r\n",
    "    \"\"\"Load an image file and turn it into the network's input array.\r\n",
    "\r\n",
    "    The image is read with OpenCV, converted to grayscale, resized to the\r\n",
    "    spatial size given by `inp_dim`, divided by 255 and reshaped to `inp_dim`.\r\n",
    "\r\n",
    "    Args:\r\n",
    "        img_path: path of the image file to read.\r\n",
    "        inp_dim: input shape, indexed here as [N, C, H, W] (this notebook\r\n",
    "            uses [1, 1, 28, 28]).\r\n",
    "\r\n",
    "    Returns:\r\n",
    "        float32 numpy array of shape `inp_dim`.\r\n",
    "\r\n",
    "    Raises:\r\n",
    "        OSError: if OpenCV cannot read `img_path`.\r\n",
    "    \"\"\"\r\n",
    "    image = cv2.imread(img_path)\r\n",
    "    if image is None:\r\n",
    "        # cv2.imread signals failure by returning None instead of raising\r\n",
    "        raise OSError(\"cannot read image: {}\".format(img_path))\r\n",
    "    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\r\n",
    "    # cv2.resize takes dsize as (width, height), i.e. (W, H) of the NCHW shape\r\n",
    "    height, width = inp_dim[2], inp_dim[3]\r\n",
    "    gray = cv2.resize(gray, (width, height)).astype(np.float32) / 255\r\n",
    "    return np.reshape(gray, inp_dim)\r\n",
    "\r\n",
    "def postprocess(out_data):\r\n",
    "    \"\"\"Print the index of the largest value in `out_data` as the prediction.\"\"\"\r\n",
    "    scores = np.squeeze(np.array(out_data))\r\n",
    "    best_index = int(scores.argmax(axis=0))\r\n",
    "    print(\"prediction: \", best_index)\r\n",
    "\r\n",
    "def mnist_demo(image_path):\r\n",
    "    \"\"\"Classify one image file and print the prediction.\r\n",
    "\r\n",
    "    Uses the notebook-level `graph` and `inp_dim`.\r\n",
    "    \"\"\"\r\n",
    "    graph.feed_input_data(preprocess(image_path, inp_dim), inp_dim)\r\n",
    "    graph.run()\r\n",
    "    postprocess(graph.get_output_data())\r\n",
    "\r\n",
    "# run the demo on the three sample digit images\r\n",
    "for sample_path in (\"data/0.jpg\", \"data/3.jpg\", \"data/6.jpg\"):\r\n",
    "    mnist_demo(sample_path)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "prediction:  0\n",
      "prediction:  3\n",
      "prediction:  6\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# dump output data of each node\n",
    "the dumped data can be found in the `data/` directory, one `{node name}_{output shape}.bin` file per node"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "source": [
    "inp_data = preprocess(\"data/6.jpg\", inp_dim)\r\n",
    "graph.feed_input_data(inp_data, inp_dim)\r\n",
    "graph.run(debug=1)\r\n",
    "out_data = graph.get_output_data()\r\n"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "dumping output tensor of each nodes\n",
      ".... data/Convolution28_1_8_28_28.bin\n",
      ".... data/ReLU32_1_8_28_28.bin\n",
      ".... data/Pooling66_1_8_14_14.bin\n",
      ".... data/Convolution110_1_16_14_14.bin\n",
      ".... data/ReLU114_1_16_14_14.bin\n",
      ".... data/Pooling160_1_256.bin\n",
      ".... data/Times212_1_10.bin\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# check output data"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "source": [
    "print(out_data)\r\n",
    "postprocess(out_data)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[[ 2.79700494e+00 -1.24417000e+01  2.06825852e-01 -3.55096745e+00\n",
      "   1.44057125e-02  5.13820505e+00  1.75181866e+01 -1.69534550e+01\n",
      "   2.51717901e+00 -5.37660265e+00]]\n",
      "prediction:  6\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# save input data\r\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# write the preprocessed input as raw C floats (array typecode 'f');\r\n",
    "# `array` is already imported in the first cell\r\n",
    "inp_data = preprocess(\"data/6.jpg\", inp_dim)\r\n",
    "data = array.array('f', inp_data.flatten())\r\n",
    "with open('data/input_6.bin', 'wb') as fp:\r\n",
    "    data.tofile(fp)"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.7.3",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.3 64-bit ('base': conda)"
  },
  "interpreter": {
   "hash": "fde18881118dd2b73aface98194fbeb10982cc327a13ea8b1d70ea74b1f8a165"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}